{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adaptive Rounding (AdaRound)\n",
    "This notebook contains a working example of AIMET adaptive rounding (AdaRound).\n",
    "\n",
    "AIMET quantization features typically use the \"nearest rounding\" technique for achieving quantization.\n",
    "When using the nearest rounding technique, the weight value is quantized to the nearest integer value.\n",
    "\n",
    "AdaRound optimizes a loss function using unlabeled training data to decide whether to quantize a specific weight to the closer integer value or the farther one.\n",
    "Using AdaRound, quantized accuracy is closer to the FP32 model than with nearest rounding.\n",
    "\n",
    "## Overall flow\n",
    "\n",
    "The example follows these high-level steps:\n",
    "\n",
    "1. Instantiate the example evaluation and training pipeline\n",
    "2. Load the FP32 model and evaluate the model to find the baseline FP32 accuracy\n",
    "3. Create a quantization simulation model (with fake quantization ops) and evaluate the quantized simuation model\n",
    "4. Apply AdaRound and evaluate the simulation model to get a post-finetuned quantized accuracy score\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "This notebook does not show state-of-the-art results. For example, it uses a relatively quantization-friendly model (Resnet18). Also, some optimization parameters like number of fine-tuning epochs are chosen to improve execution speed in the notebook.\n",
    "\n",
    "</div>"
   ]
  },
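  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two rounding choices can be illustrated with a minimal NumPy sketch. This is not AIMET code: the weight values, the `scale`, and the hand-picked `round_up` mask below are assumptions chosen only for illustration. AdaRound learns the up/down decision per weight by optimizing a loss on data, which this sketch does not attempt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical weights and quantization step size (illustrative values only)\n",
    "weights = np.array([0.26, -0.74, 0.49, 1.12], dtype=np.float32)\n",
    "scale = 0.5\n",
    "\n",
    "# Nearest rounding: each weight snaps to the closest grid point\n",
    "nearest = np.round(weights / scale) * scale\n",
    "\n",
    "# Floor and ceil are the two candidates that adaptive rounding chooses between\n",
    "floor_q = np.floor(weights / scale) * scale\n",
    "ceil_q = np.ceil(weights / scale) * scale\n",
    "\n",
    "# A hand-picked up/down choice per weight, standing in for a learned rounding decision\n",
    "round_up = np.array([False, True, True, False])\n",
    "chosen = np.where(round_up, ceil_q, floor_q)\n",
    "\n",
    "print(\"nearest rounding:\", nearest)\n",
    "print(\"floor candidates:\", floor_q)\n",
    "print(\"ceil candidates: \", ceil_q)\n",
    "print(\"hand-picked mix: \", chosen)"
   ]
  },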
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Dataset\n",
    "\n",
    "This example does image classification on the ImageNet dataset. If you already have a version of the data set, use that. Otherwise download the data set, for example from https://image-net.org/challenges/LSVRC/2012/index .\n",
    "\n",
    "\n",
    "</div>\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example: The entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. However, for the purpose of running this notebook, you can reduce the dataset to, say, two samples per class.\n",
    "\n",
    "</div>\n",
    "\n",
    "Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_DIR = '/path/to/dataset/'         # Replace this path with a real directory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 1. Instantiate the example training and validation pipeline\n",
    "\n",
    "**Use the following training and validation loop for the image classification task.**\n",
    "\n",
    "Things to note:\n",
    "\n",
    "- AIMET does not put limitations on how the evaluation pipeline is written. AIMET creates an `onnxruntime.InferenceSession` for the quantized model, which can be run like a regular `InferenceSession`. `sim.session` can be used in place of the any other `InferenceSession` when doing inference/evaluation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "import torch\n",
    "import onnxruntime as ort\n",
    "import numpy as np\n",
    "\n",
    "BATCH_SIZE = 32\n",
    "NUM_CALIBRATION_SAMPLES = 256\n",
    "NUM_EVAL_SAMPLES = 50000\n",
    "\n",
    "preprocess = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "imagenet_data = torchvision.datasets.ImageNet(DATASET_DIR,\n",
    "                                              split=\"val\",\n",
    "                                              transform=preprocess)\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(imagenet_data,\n",
    "                                         batch_size=BATCH_SIZE,\n",
    "                                         shuffle=True,\n",
    "                                         num_workers=4)\n",
    "\n",
    "def evaluate(session: ort.InferenceSession):\n",
    "    correct_predictions = 0\n",
    "    total_samples = 0\n",
    "    for inputs, labels in tqdm(dataloader):\n",
    "        inputs, labels = inputs.numpy(), labels.numpy()\n",
    "        input_name = session.get_inputs()[0].name\n",
    "        pred_probs, *_ = session.run(None, {input_name: inputs})\n",
    "        pred_labels = np.argmax(pred_probs, axis=1)\n",
    "        correct_predictions += np.sum(pred_labels == labels)\n",
    "        total_samples += labels.shape[0]\n",
    "    return correct_predictions / total_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 2. Convert an FP32 PyTorch model to ONNX, simplify & then evaluate baseline FP32 accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**2.1 Export a pretrained resnet18 model to onnx**\n",
    "\n",
    "You can load any pretrained PyTorch model instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "import onnx\n",
    "\n",
    "input_shape = (1, 3, 224, 224)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)\n",
    "dummy_input = torch.randn(input_shape)\n",
    "filename = \"./resnet18.onnx\"\n",
    "\n",
    "# Load a pretrained ResNet-18 model in torch\n",
    "pt_model = resnet18(pretrained=True)\n",
    "\n",
    "# Export the torch model to onnx\n",
    "torch.onnx.export(pt_model.eval(),\n",
    "                  dummy_input,\n",
    "                  filename,\n",
    "                  input_names=['input'],\n",
    "                  output_names=['output'],\n",
    "                  dynamic_axes={\n",
    "                      'input' : {0 : 'batch_size'},\n",
    "                      'output' : {0 : 'batch_size'},\n",
    "                  }\n",
    "                  )\n",
    "\n",
    "model = onnx.load_model(filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "**2.2 (Optional) Simplify the onnx model**\n",
    "\n",
    "It is recommended to simplify the model before using AIMET as it can improve quantized accuracy and runtime performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxsim import simplify\n",
    "model, _ = simplify(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "**2.3 Decide whether to place the model on a CPU or CUDA device**\n",
    "\n",
    "This example uses CUDA if it is available. You can change this logic and force a device placement if needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cudnn_conv_algo_search is fixing it to default to avoid changing in accuracies/outputs at every inference\n",
    "if 'CUDAExecutionProvider' in ort.get_available_providers():\n",
    "    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']\n",
    "else:\n",
    "    providers = ['CPUExecutionProvider']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "**2.4 Create an InferenceSession and determine the model's FP32 accuracy**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = ort.InferenceSession(model.SerializeToString(), providers=providers)\n",
    "accuracy = evaluate(sess)\n",
    "print(f\"FP32 model accuracy: {accuracy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 3. Create a quantization simulation model and determine quantized accuracy\n",
    "\n",
    "\n",
    "**3.1 Fold BatchNormalization layers**\n",
    "\n",
    "Before calculating the simulated quantized accuracy using QuantizationSimModel, fold the BatchNormalization (BN) layers into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.\n",
    "\n",
    "BN folding improves inference performance on quantized runtimes but can degrade accuracy on these platforms. This step simulates this on-target drop in accuracy. \n",
    "\n",
    "\n",
    "Use the following code to call AIMET to fold the BN layers in-place on the given model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight\n",
    "\n",
    "fold_all_batch_norms_to_weight(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "**3.2 Create a QuantizationSimModel**\n",
    "\n",
    " In this step, AIMET inserts fake quantization ops in the model graph and configures them.\n",
    "\n",
    "Key parameters:\n",
    "\n",
    "- Setting **activation_type** to int8 performs all activation quantizations in the model using integer 8-bit precision\n",
    "- Setting **param_type** to int8 performs all parameter quantizations in the model using integer 8-bit precision\n",
    "\n",
    "See [QuantizationSimModel in the AIMET API documentation](https://quic.github.io/aimet-pages/AimetDocs/api_docs/torch_quantsim.html#aimet_torch.quantsim.QuantizationSimModel.compute_encodings) for a full explanation of the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import aimet_onnx\n",
    "from aimet_common.defs import QuantScheme\n",
    "from aimet_onnx.quantsim import QuantizationSimModel\n",
    "\n",
    "sim = QuantizationSimModel(model=copy.deepcopy(model),\n",
    "                           quant_scheme=QuantScheme.min_max,\n",
    "                           param_type=aimet_onnx.int8,\n",
    "                           activation_type=aimet_onnx.int8,\n",
    "                           providers=providers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "AIMET has added quantizer nodes to the model graph, but before the sim model can be used for inference or training, scale and offset quantization parameters must be calculated for each quantizer node by passing unlabeled data samples through the model to collect range statistics. This process is sometimes referred to as calibration. AIMET refers to it as \"computing encodings\".\n",
    "\n",
    "---\n",
    "**3.3 Pass unlabeled data samples through the model**\n",
    "\n",
    "The following code is one way get unlabeled samples for calibration. It uses the existing pytorch train or validation data loader and converts samples to an onnxruntime-compatible format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "input_name = model.graph.input[0].name\n",
    "num_batches = NUM_CALIBRATION_SAMPLES // BATCH_SIZE\n",
    "onnx_data = [{input_name: data.numpy()} for data, labels in itertools.islice(dataloader, num_batches)]\n",
    "\n",
    "sim.compute_encodings(onnx_data)"
   ]
  },
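  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of computing encodings more concrete, the following sketch derives a scale and offset from the observed minimum and maximum of some sample values and then applies quantize-dequantize with them. It is a simplified illustration of asymmetric min-max quantization, not AIMET's implementation: the helpers `minmax_encoding` and `fake_quantize` and the random sample data are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def minmax_encoding(samples, bitwidth=8):\n",
    "    # Hypothetical helper: derive scale and offset from the observed value range\n",
    "    lo = min(float(samples.min()), 0.0)   # the range must include zero\n",
    "    hi = max(float(samples.max()), 0.0)\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    scale = (hi - lo) / num_steps\n",
    "    offset = round(lo / scale)            # integer grid index of the range minimum\n",
    "    return scale, offset\n",
    "\n",
    "def fake_quantize(x, scale, offset, bitwidth=8):\n",
    "    # Quantize to the integer grid, clamp, then map back to floating point\n",
    "    q = np.clip(np.round(x / scale) - offset, 0, 2 ** bitwidth - 1)\n",
    "    return (q + offset) * scale\n",
    "\n",
    "# Stand-in for activations collected during calibration (illustrative data only)\n",
    "activations = np.random.default_rng(0).normal(size=1000).astype(np.float32)\n",
    "scale, offset = minmax_encoding(activations)\n",
    "error = np.abs(fake_quantize(activations, scale, offset) - activations).max()\n",
    "print(f\"scale={scale:.5f}, offset={offset}, max abs error={error:.5f}\")"
   ]
  },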
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A few notes regarding the data samples:\n",
    "\n",
    "- A very small percentage of the data samples are needed. For example, the training dataset for ImageNet has 1M samples; 500 or 1000 suffice to compute encodings.\n",
    "- The samples should be reasonably well distributed. While it's not necessary to cover all classes, avoid extreme scenarios like using only dark or only light samples. That is, using only pictures captured at night, say, could skew the results.\n",
    "\n",
    "---\n",
    "**3.4 Evaluate the quantized model**\n",
    "\n",
    "You can pass `sim.session` to the eval function to evaluate the quantsim model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Evaluate the pre-adaround model\n",
    "accuracy = evaluate(sim.session)\n",
    "print(f\"Pre-adaround sim accuracy {accuracy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 4. Apply AdaRound\n",
    "\n",
    "**4.1 Run AdaRound optimization**\n",
    "\n",
    "Some key parameters:\n",
    "\n",
    "- **inputs:**  is a collection (e.g., `List[Dict[str, np.ndarray]]`) of InferenceSession inputs for the model. AdaRound needs a dataloader in order to use data samples to learn the rounding vectors.\n",
    "- **iterations:** is the number of iterations to apply to each layer. Default value is 10000, and we strongly recommend using at least this number. This example uses 32 to speed up execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Apply AdaRound to the model weights\n",
    "aimet_onnx.apply_adaround(sim, onnx_data, iterations=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**4.2 Recompute activation encodings**\n",
    "\n",
    "In this case, activation encodings were already computed in Step 3.3. However, since AdaRound-modified weights may affect the distribution of activations in the model, it is recommended to recompute the activation encodings after applying AdaRound."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute activation encodings (weight encodings are frozen)\n",
    "sim.compute_encodings(onnx_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**4.3 Evaluate the optimized sim**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the ada-rounded model\n",
    "accuracy = evaluate(sim.session)\n",
    "print(f\"Post-adaround sim accuracy: {accuracy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There might be little gain in accuracy after this limited application of AdaRound. Experiment with the hyper-parameters to get better results.\n",
    "\n",
    "---\n",
    "\n",
    "## Next steps\n",
    "\n",
    "**Export the model and encodings.**\n",
    "\n",
    "- Export the model with the updated weights but without the fake quant ops. \n",
    "- Export the encodings (scale and offset quantization parameters). AIMET QuantizationSimModel provides an export API for this purpose.\n",
    "\n",
    "The following code performs these exports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sim.export(path='.', filename_prefix='resnet18_after_adaround')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For more information\n",
    "\n",
    "See the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) for details about the AIMET APIs and optional parameters.\n",
    "\n",
    "See the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/torch/quantization) to learn how to use other AIMET post-training quantization techniques."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
